package org.pkl.core.generator

import com.oracle.truffle.api.dsl.GeneratedBy
import com.squareup.javapoet.ClassName
import javax.lang.model.SourceVersion
import javax.annotation.processing.RoundEnvironment
import com.squareup.javapoet.MethodSpec
import com.squareup.javapoet.JavaFile
import com.squareup.javapoet.TypeSpec
import javax.annotation.processing.AbstractProcessor
import javax.lang.model.element.*
import javax.lang.model.type.TypeMirror

/**
 * Generates a subclass of `org.pkl.core.stdlib.registry.ExternalMemberRegistry`
 * for each stdlib module, plus a factory (`MemberRegistryFactory`) that instantiates them.
 * Generated classes are written to `generated/truffle/org/pkl/core/stdlib/registry`.
 *
 * Inputs:
 * - Generated Truffle node classes for stdlib members.
 *   These classes are located in subpackages of `org.pkl.core.stdlib`
 *   and identified via their `@GeneratedBy` annotations.
 * - `@PklName` annotations on hand-written node classes from which Truffle node classes are generated,
 *   and on stdlib packages (via `package-info.java`).
 */
class MemberRegistryGenerator : AbstractProcessor() {
  /** Simple-name suffix that identifies generated Truffle node classes. */
  private val truffleNodeClassSuffix = "NodeGen"

  /** Simple-name suffix of the classes enclosing node classes; stripped to derive the Pkl class name. */
  private val truffleNodeFactorySuffix = "NodesFactory"
  private val stdLibPackageName: String = "org.pkl.core.stdlib"
  private val registryPackageName: String = "$stdLibPackageName.registry"
  private val modulePackageName: String = "org.pkl.core.module"

  private val externalMemberRegistryClassName: ClassName =
    ClassName.get(registryPackageName, "ExternalMemberRegistry")
  private val emptyMemberRegistryClassName: ClassName =
    ClassName.get(registryPackageName, "EmptyMemberRegistry")
  private val memberRegistryFactoryClassName: ClassName =
    ClassName.get(registryPackageName, "MemberRegistryFactory")
  private val moduleKeyClassName: ClassName =
    ClassName.get(modulePackageName, "ModuleKey")
  private val moduleKeysClassName: ClassName =
    ClassName.get(modulePackageName, "ModuleKeys")

  override fun getSupportedAnnotationTypes(): Set<String> = setOf(GeneratedBy::class.java.name)

  override fun getSupportedSourceVersion(): SourceVersion = SourceVersion.RELEASE_11

  /**
   * Generates one `<Module>MemberRegistry` class per stdlib package that contains node classes,
   * plus `MemberRegistryFactory`, which maps a stdlib module key to its registry.
   *
   * NOTE(review): every round that reaches the generation step writes `MemberRegistryFactory`
   * again, and the Filer rejects re-creating a file. This presumably relies on all stdlib node
   * classes being generated in a single round — confirm.
   */
  override fun process(annotations: Set<TypeElement>, roundEnv: RoundEnvironment): Boolean {
    // No `@GeneratedBy` elements in this round, so there is nothing to register.
    if (annotations.isEmpty()) return true

    val nodeClassesByPackage = collectNodeClasses(roundEnv)
    generateRegistryClasses(nodeClassesByPackage)
    generateRegistryFactoryClass(nodeClassesByPackage.keys)

    return true
  }

  /**
   * Returns the `*NodeGen` classes annotated with `@GeneratedBy` that live in [stdLibPackageName]
   * or one of its subpackages, grouped by package.
   *
   * Classes are sorted by enclosing class name (classes declared directly in a package first),
   * then by their own simple name, so that registrations are emitted in a stable order.
   */
  private fun collectNodeClasses(roundEnv: RoundEnvironment): Map<PackageElement, List<TypeElement>> {
    // Match on a package boundary so that a sibling package such as `org.pkl.core.stdlibx` is not picked up.
    val stdLibPrefix = "$stdLibPackageName."
    return roundEnv
      .getElementsAnnotatedWith(GeneratedBy::class.java)
      .asSequence()
      .filterIsInstance<TypeElement>()
      .filter { it.qualifiedName.toString().startsWith(stdLibPrefix) }
      .filter { it.simpleName.toString().endsWith(truffleNodeClassSuffix) }
      .sortedWith(compareBy(
        { if (it.enclosingElement.kind == ElementKind.PACKAGE) "" else it.enclosingElement.simpleName.toString() },
        { it.simpleName.toString() }
      ))
      .groupBy { processingEnv.elementUtils.getPackageOf(it) }
  }

  /** Writes one registry class for each entry of [nodeClassesByPackage]. */
  private fun generateRegistryClasses(nodeClassesByPackage: Map<PackageElement, List<TypeElement>>) {
    for ((pkg, nodeClasses) in nodeClassesByPackage) {
      generateRegistryClass(pkg, nodeClasses)
    }
  }

  /**
   * Writes the registry class for the Pkl module backed by [pkg].
   *
   * Its constructor calls `register("<qualified member name>", <NodeGen>::create)` once per
   * node class. Qualified names have the form `pkl.<module>#<member>` for module members and
   * `pkl.<module>#<Class>.<member>` for class members.
   */
  private fun generateRegistryClass(pkg: PackageElement, nodeClasses: List<TypeElement>) {
    val pklModuleName = pklModuleNameOf(pkg)
    val pklModuleNameCapitalized = pklModuleName.capitalize()

    val registryClass = TypeSpec.classBuilder(registryClassNameOf(pklModuleName))
      .addJavadoc("Generated by {@link ${this::class.qualifiedName}}.")
      .addModifiers(Modifier.FINAL)
      .superclass(externalMemberRegistryClassName)
    val registryClassConstructor = MethodSpec.constructorBuilder()

    for (nodeClass in nodeClasses) {
      val enclosingClass = nodeClass.enclosingElement
      val pklClassName = getAnnotatedPklName(enclosingClass)
        ?: enclosingClass.simpleName.toString().removeSuffix(truffleNodeFactorySuffix)
      val pklMemberName = getAnnotatedPklName(nodeClass)
        ?: nodeClass.simpleName.toString().removeSuffix(truffleNodeClassSuffix)
      val pklMemberNameQualified = when (pklClassName) {
        // By convention, the top-level class containing node classes
        // for *module* members is named `<SimpleModuleName>Nodes`.
        // Example: `BaseNodes` for pkl.base
        pklModuleNameCapitalized ->
          "pkl.$pklModuleName#$pklMemberName"
        else ->
          "pkl.$pklModuleName#$pklClassName.$pklMemberName"
      }
      // Lets incremental builds regenerate this registry when the node class changes.
      registryClass.addOriginatingElement(nodeClass)
      registryClassConstructor
        .addStatement("register(\$S, \$T::create)", pklMemberNameQualified, nodeClass)
    }

    registryClass.addMethod(registryClassConstructor.build())
    val javaFile = JavaFile.builder(registryPackageName, registryClass.build()).build()
    javaFile.writeTo(processingEnv.filer)
  }

  /**
   * Writes `MemberRegistryFactory`, whose static `get(ModuleKey)` returns the generated registry
   * for a stdlib module (selected by the key URI's scheme-specific part) and
   * `EmptyMemberRegistry.INSTANCE` for any other module key.
   */
  private fun generateRegistryFactoryClass(packages: Collection<PackageElement>) {
    val registryFactoryClass = TypeSpec.classBuilder(memberRegistryFactoryClassName)
      .addJavadoc("Generated by {@link ${this::class.qualifiedName}}.")
      .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
    // The factory only exposes a static method, so its constructor is private.
    val registryFactoryConstructor = MethodSpec.constructorBuilder()
      .addModifiers(Modifier.PRIVATE)
    registryFactoryClass.addMethod(registryFactoryConstructor.build())
    val registryFactoryGetMethod = MethodSpec.methodBuilder("get")
      .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
      .addParameter(moduleKeyClassName, "moduleKey")
      .returns(externalMemberRegistryClassName)
      .beginControlFlow("if (!\$T.isStdLibModule(moduleKey))", moduleKeysClassName)
      .addStatement("return \$T.INSTANCE", emptyMemberRegistryClassName)
      .endControlFlow()
      .beginControlFlow("switch (moduleKey.getUri().getSchemeSpecificPart())")

    for (pkg in packages) {
      val pklModuleName = pklModuleNameOf(pkg)

      // declare dependency on package-info.java (for `@PklName`)
      registryFactoryClass.addOriginatingElement(pkg)
      registryFactoryGetMethod
        .addCode("case \$S:\n", pklModuleName)
        .addStatement("  return new \$T()", registryClassNameOf(pklModuleName))
    }

    registryFactoryGetMethod
      .addCode("default:\n")
      .addStatement("  return \$T.INSTANCE", emptyMemberRegistryClassName)
      .endControlFlow()
    registryFactoryClass.addMethod(registryFactoryGetMethod.build())
    val javaFile = JavaFile.builder(registryPackageName, registryFactoryClass.build()).build()
    javaFile.writeTo(processingEnv.filer)
  }

  /** Returns the Pkl module name for [pkg]: its `@PklName` if present, else the package's simple name. */
  private fun pklModuleNameOf(pkg: PackageElement): String =
    getAnnotatedPklName(pkg) ?: pkg.simpleName.toString()

  /** Returns the name of the generated registry class for the Pkl module named [pklModuleName]. */
  private fun registryClassNameOf(pklModuleName: String): ClassName =
    ClassName.get(registryPackageName, "${pklModuleName.capitalize()}MemberRegistry")

  /**
   * Returns the Pkl name declared on [element] via `@PklName`, or `null` if there is none.
   *
   * For a Truffle-generated element (annotated with `@GeneratedBy`), the lookup is delegated to
   * the element referenced by `@GeneratedBy`'s `value`, i.e. the class it was generated from.
   * Annotations are inspected in order and the first `@PklName`/`@GeneratedBy` found decides.
   */
  private fun getAnnotatedPklName(element: Element): String? {
    for (annotation in element.annotationMirrors) {
      when (annotation.annotationType.asElement().simpleName.toString()) {
        "PklName" ->
          return annotation.elementValues.values.iterator().next().value.toString()
        "GeneratedBy" -> {
          // Look up `value` by name: `@GeneratedBy` may also carry `methodName`, so position is unreliable.
          val generator = findAnnotationValue(annotation, "value")?.value as? TypeMirror ?: return null
          val generatorElement = processingEnv.typeUtils.asElement(generator) ?: return null
          return getAnnotatedPklName(generatorElement)
        }
      }
    }
    return null
  }

  /** Returns the explicitly specified value of member [memberName] of [annotation], or `null` if it was not set. */
  private fun findAnnotationValue(annotation: AnnotationMirror, memberName: String): AnnotationValue? =
    annotation.elementValues.entries
      .firstOrNull { it.key.simpleName.contentEquals(memberName) }
      ?.value

  private fun String.capitalize(): String = replaceFirstChar { it.titlecaseChar() }
}
